Accelerating Transformers with PyTorch 2.0: Speedups for Both Training and Inference!
Transformers are now the dominant model architecture across domains (text, images, speech), and the recently released PyTorch 2.0 further optimizes its Transformer modules to support efficient training and inference of Transformer models. Specifically, PyTorch 2.0 introduces a new function in torch.nn.functional: torch.nn.functional.scaled_dot_product_attention, abbreviated here as SDPA. SDPA is backed by high-performance kernels, so you can call it directly to accelerate both training and inference.
Let's take a quick look at the function's signature and arguments:
torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) → Tensor:
"""
Args:
query (Tensor): Query tensor; shape :math:`(N, ..., L, E)`.
key (Tensor): Key tensor; shape :math:`(N, ..., S, E)`.
value (Tensor): Value tensor; shape :math:`(N, ..., S, Ev)`.
attn_mask (optional Tensor): Attention mask; shape :math:`(N, ..., L, S)`. Two types of masks are supported.
A boolean mask where a value of True indicates that the element *should* take part in attention.
A float mask of the same type as query, key, value that is added to the attention score.
dropout_p (float): Dropout probability; if greater than 0.0, dropout is applied
is_causal (bool): If true, assumes causal attention masking and errors if both attn_mask and is_causal
are set.
scale (optional float): Scaling factor applied prior to softmax. If None, the default value is set
to :math:`\frac{1}{\sqrt{E}}`.
Returns:
output (Tensor): Attention output; shape :math:`(N, ..., L, Ev)`.
"""
pass
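As a quick smoke test, a direct call looks like the following (a minimal sketch; the shapes are purely illustrative):
import torch
import torch.nn.functional as F

# batch N=2, 4 heads, target length L=8, source length S=8, head dim E=16
query = torch.randn(2, 4, 8, 16)
key = torch.randn(2, 4, 8, 16)
value = torch.randn(2, 4, 8, 16)
out = F.scaled_dot_product_attention(query, key, value, is_causal=True)
print(out.shape)  # torch.Size([2, 4, 8, 16]), i.e. (N, ..., L, Ev)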
SDPA implements the core of the attention module (scaled dot-product attention); the function is equivalent to the following code:
# Reference ("math") implementation; Q, K, V are query/key/value, L and S the query/key sequence lengths.
scale_factor = 1 / math.sqrt(Q.size(-1)) if scale is None else scale
attn_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) if is_causal else attn_mask
if attn_mask is not None and attn_mask.dtype == torch.bool:
    # turn a boolean mask into an additive float mask: -inf where attention is not allowed
    attn_mask = torch.zeros_like(attn_mask, dtype=Q.dtype).masked_fill(attn_mask.logical_not(), float("-inf"))
attn_weight = torch.softmax(Q @ K.transpose(-2, -1) * scale_factor + (attn_mask if attn_mask is not None else 0), dim=-1)
attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
return attn_weight @ V
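To convince ourselves of this equivalence, here is a quick check with small random tensors, no mask and no dropout (my own sketch, not an official test):
import math
import torch
import torch.nn.functional as F

Q = torch.randn(2, 4, 8, 16)
K = torch.randn(2, 4, 8, 16)
V = torch.randn(2, 4, 8, 16)
# reference math, matching the snippet above with attn_mask=None and dropout_p=0.0
ref = torch.softmax(Q @ K.transpose(-2, -1) / math.sqrt(Q.size(-1)), dim=-1) @ V
out = F.scaled_dot_product_attention(Q, K, V)
print(torch.allclose(ref, out, atol=1e-5))  # expected: True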
This function is already wired into PyTorch's existing Transformer APIs, which means you get the SDPA speedup simply by building your model with torch.nn.MultiheadAttention and torch.nn.TransformerEncoderLayer. And if you need custom behavior, you can call the function directly to build your own attention module, as in the sketch below.
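For reference, a custom multi-head self-attention block on top of SDPA might look roughly like this (my own sketch; SDPASelfAttention and its parameters are made up for illustration and are not a PyTorch or diffusers class):
import torch
import torch.nn as nn
import torch.nn.functional as F

class SDPASelfAttention(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = dropout

    def forward(self, x, is_causal=False):
        B, L, E = x.shape
        # project to q, k, v and reshape to (B, num_heads, L, head_dim)
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        q, k, v = [t.view(B, L, self.num_heads, self.head_dim).transpose(1, 2) for t in (q, k, v)]
        out = F.scaled_dot_product_attention(
            q, k, v, dropout_p=self.dropout if self.training else 0.0, is_causal=is_causal
        )
        # merge heads back to (B, L, E) and apply the output projection
        out = out.transpose(1, 2).reshape(B, L, E)
        return self.proj(out)
For example, SDPASelfAttention(embed_dim=512, num_heads=8)(torch.randn(2, 128, 512)) behaves like a standard self-attention layer while still dispatching to the fused SDPA kernels.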
SDPA gets its speedup from the optimized kernels implemented behind it. Three kernels are currently supported:
sdpa_flash: FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
sdpa_mem_eff: Memory-Efficient Attention
sdpa_math: A PyTorch implementation defined in C++
sdpa_flash supports FP16 training and inference on GPUs with SM80+ architectures, while sdpa_mem_eff supports FP16 and FP32 on most GPUs. If neither of these two kernels is available, SDPA falls back to sdpa_math, the generic C++ implementation. All three kernels are enabled by default; when you call SDPA, it picks the best available kernel for your inputs.
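By the way, if you just want to inspect the global on/off switches for the three kernels, torch.backends.cuda exposes query functions for them (a small sketch; as far as I know these are available in PyTorch 2.0):
import torch

# All three flags default to True; the dispatcher then picks among the enabled kernels.
print(torch.backends.cuda.flash_sdp_enabled())          # sdpa_flash
print(torch.backends.cuda.mem_efficient_sdp_enabled())  # sdpa_mem_eff
print(torch.backends.cuda.math_sdp_enabled())           # sdpa_math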
In most cases you don't need to care which kernel is actually picked, since the dispatcher already makes the best choice. But if you want explicit control, you can use torch.backends.cuda.sdp_kernel() to disable specific kernels. It is a context manager; for example, to disable sdpa_math you can call it like this:
import torch
import torch.nn.functional as F
query = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
with torch.backends.cuda.sdp_kernel(enable_math=False):
    F.scaled_dot_product_attention(query, key, value)
With sdpa_math disabled, SDPA can only choose between sdpa_flash and sdpa_mem_eff. If you disable two of the kernels, that is equivalent to explicitly selecting the remaining one; for example, the following code effectively forces the sdpa_mem_eff kernel:
query = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True):
    F.scaled_dot_product_attention(query, key, value)
However, if that kernel is not supported on your current platform, you will get an error:
RuntimeError: No available kernel. Aborting execution.
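If you force a kernel like this in code that has to run on different hardware, it may be worth guarding against that error and falling back to the default dispatch; one possible pattern (my own sketch, not an official recipe):
import torch
import torch.nn.functional as F

q = k = v = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
try:
    # insist on FlashAttention only
    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
        out = F.scaled_dot_product_attention(q, k, v)
except RuntimeError:
    # no supported kernel on this platform, so let SDPA pick from all kernels again
    out = F.scaled_dot_product_attention(q, k, v)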
We can use sdp_kernel to compare the runtime of the different kernels. The code is as follows:
import torch
import torch.utils.benchmark as benchmark
from torch.backends.cuda import sdp_kernel, SDPBackend
import torch.nn.functional as F
# Lets define a helpful benchmarking function:
def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    return t0.blocked_autorange().mean * 1e6
# Lets define the hyper-parameters of our input
batch_size = 32
max_sequence_len = 1024
num_heads = 32
embed_dimension = 32
dtype = torch.float16
device = "cuda"
query = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
key = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
value = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
print(f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
# Lets explore the speed of each of the 3 implementations
# Helpful arg mapper
backend_map = {
SDPBackend.MATH: {"enable_math": True, "enable_flash": False, "enable_mem_efficient": False},
SDPBackend.FLASH_ATTENTION: {"enable_math": False, "enable_flash": True, "enable_mem_efficient": False},
SDPBackend.EFFICIENT_ATTENTION: {
"enable_math": False, "enable_flash": False, "enable_mem_efficient": True}
}
with sdp_kernel(**backend_map[SDPBackend.MATH]):
    print(f"The math implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")

with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
    try:
        print(f"The flash attention implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
    except RuntimeError:
        print("FlashAttention is not supported. See warnings for reasons.")

with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
    try:
        print(f"The memory efficient implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
    except RuntimeError:
        print("EfficientAttention is not supported. See warnings for reasons.")
Running this on a V100 gives:
The default implementation runs in 6569.854 microseconds
The math implementation runs in 16091.686 microseconds
<timeit-src>:6: UserWarning: Memory efficient kernel not used because: (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.h:527.)
<timeit-src>:6: UserWarning: Memory Efficient attention has been runtime disabled. (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.h:338.)
<timeit-src>:6: UserWarning: Flash attention kernel not used because: (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.h:529.)
<timeit-src>:6: UserWarning: Flash attention only supports sm75 and sm8x gpu architectures. Attempting to run on a sm 7.0 gpu. (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.h:352.)
FlashAttention is not supported. See warnings for reasons.
The memory efficient implementation runs in 6595.339 microseconds
Well, the V100 is an sm 7.0 card and does not support FlashAttention, but we can see that the kernel picked by default is sdpa_mem_eff, which is dramatically faster than sdpa_math (about 6.6 ms vs 16.1 ms). Switching to an A100, the results look like this:
The default implementation runs in 2831.521 microseconds
The math implementation runs in 7001.696 microseconds
The flash attention implementation runs in 2829.635 microseconds
The memory efficient implementation runs in 3011.410 microseconds
The A100 does support FlashAttention, and the default dispatch picks sdpa_flash, which gives the shortest runtime here; the A100 is more than twice as fast as the V100.
Finally, let's look at a concrete example: using SDPA to speed up Stable Diffusion in diffusers. diffusers already ships an AttnProcessor2_0 built on scaled_dot_product_attention:
class AttnProcessor2_0:
    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        inner_dim = hidden_states.shape[-1]

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.cross_attention_norm:
            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        head_dim = inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states
Taking Stable Diffusion 1.5 as an example, we first set the attention processor to the default CrossAttnProcessor:
import torch
from diffusers import StableDiffusionPipeline
from diffusers.models.cross_attention import AttnProcessor2_0, CrossAttnProcessor
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
pipe.unet.set_attn_processor(CrossAttnProcessor())
prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]
On a V100 this takes about 3.6 s (1.9 s on an A100), with a peak memory usage of roughly 5.9 GB. Next, we swap the attention processor to AttnProcessor2_0:
pipe.unet.set_attn_processor(AttnProcessor2_0())
With SDPA the run takes about 3 s (1.6 s on an A100) and peak memory drops to about 4.7 GB, so we not only get a speedup but also a lower memory footprint.
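The timings and memory numbers quoted here come from simple ad-hoc measurements; something along these lines could be used to reproduce them (my own sketch; run_and_measure is not part of diffusers):
import time
import torch

def run_and_measure(pipe, prompt):
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    start = time.time()
    image = pipe(prompt).images[0]
    torch.cuda.synchronize()
    print(f"{time.time() - start:.2f} s, peak memory {torch.cuda.max_memory_allocated() / 1024 ** 3:.2f} GB")
    return image

image = run_and_measure(pipe, prompt)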
In addition, PyTorch 2.0 introduces torch.compile() for whole-model acceleration, and we can stack it on top of SDPA for a further speedup:
import torch
from diffusers import StableDiffusionPipeline
from diffusers.models.cross_attention import AttnProcessor2_0, CrossAttnProcessor
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to(
"cuda"
)
pipe.unet.set_attn_processor(AttnProcessor2_0())  # this is actually used by default
pipe.unet = torch.compile(pipe.unet)
batch_size = 8
prompt = "A photo of an astronaut riding a horse on marse."
images = pipe(prompt, num_inference_steps=steps, num_images_per_prompt=batch_size).images
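Keep in mind that torch.compile pays its compilation cost on the first call, so for a fair timing you should run a warm-up pass first and only time the subsequent calls; roughly like this (my own sketch):
import time

# warm-up: the first call triggers compilation and is much slower than steady state
_ = pipe(prompt, num_inference_steps=steps, num_images_per_prompt=batch_size).images

torch.cuda.synchronize()
start = time.time()
images = pipe(prompt, num_inference_steps=steps, num_images_per_prompt=batch_size).images
torch.cuda.synchronize()
print(f"{time.time() - start:.1f} s for a batch of {batch_size}")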
With batch_size = 8, this runs in about 16 s on the V100 (6.6 s on an A100), while the SDPA-only version takes about 17 s (7.3 s on an A100), so torch.compile does give an additional speedup (though the V100 is clearly no match for the A100).
Note that these comparisons are not rigorous; the PyTorch team has already published a systematic evaluation in the blog post Accelerated Diffusers with PyTorch 2.0.
References
https://huggingface.co/docs/diffusers/v0.13.0/en/optimization/torch2.0
https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html
https://pytorch.org/blog/accelerated-diffusers-pt-20/
https://pytorch.org/blog/accelerated-pytorch-2/